{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1.对层进行处理\n",
    "\n",
    "现在，我们得到了一个 LeNet，实例化之后打印，结果如下所示：\n",
    "\n",
    "```\n",
    "LeNet5(\n",
    "  (conv1): Conv2d(224, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
    "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
    "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
    "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
    "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
    ")\n",
    "```\n",
    "\n",
    "然后，我们现在需要把所有的带有 \"conv\" 的层全部删除（替换）我们该怎么做呢？\n",
    "\n",
    "我们需要重新写一个模型嘛？\n",
    "\n",
    "NOOOOO！！！\n",
    "\n",
    "不需要， Pytorch 的层是可以随意替换的！！！我们随时可以增加，修改。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class LeNet5(nn.Module):\n",
    "    def __init__(self, in_dim, n_class):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_dim, 6, 5, padding=2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16*5*5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, n_class)\n",
    "        \n",
    "        # 参数初始化函数\n",
    "        for p in self.modules():\n",
    "            if isinstance(p, nn.Conv2d):\n",
    "                nn.init.xavier_normal(p.weight.data)\n",
    "            elif isinstance(p, nn.Linear):\n",
    "                nn.init.normal(p.weight.data)\n",
    "\n",
    "    # 向前传播\n",
    "    def forward(self, x):\n",
    "        x = F.max_pool2d(F.relu(self.conv1(x)), 2)\n",
    "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
    "        x = x.view(-1, self.num_flat_features(x))\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "    \n",
    "    def num_flat_features(self, x):\n",
    "        size = x.size()[1:]\n",
    "        num_features = 1\n",
    "        for s in size:\n",
    "            num_features *= s\n",
    "        return num_features\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# 实例化 LeNet\n",
    "\n",
    "lenet = LeNet5(224, 10)\n",
    "\n",
    "del_list = [] # 将所有需要删除的层的名字放进这列表里面\n",
    "for name, module in lenet.named_children():\n",
    "    # 对所有的层的名字进行匹配，如果层中含有conv则把名字放入等待删除列表中\n",
    "    if re.match(\"conv\", name) != None:\n",
    "        del_list.append(name)\n",
    "\n",
    "# 获取了所有带 \"conv\" 层的名字了，现在开始删除\n",
    "for name in del_list:\n",
    "    delattr(lenet, name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "print(lenet) 得到的结果如下所示\n",
    "\n",
    "```\n",
    "LeNet5(|\n",
    "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
    "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
    "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.改变预训练模型中的某 sequential 层里面的某一层\n",
    "\n",
    "这里我们以Alexnet为例\n",
    "\n",
    "```\n",
    "AlexNet(\n",
    "  (features): Sequential(\n",
    "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
    "    (1): ReLU(inplace)\n",
    "    (2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=False)\n",
    "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
    "    (4): ReLU(inplace)\n",
    "    (5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=False)\n",
    "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
    "    (7): ReLU(inplace)\n",
    "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
    "    (9): ReLU(inplace)\n",
    "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
    "    (11): ReLU(inplace)\n",
    "    (12): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=False)\n",
    "  )\n",
    "  (classifier): Sequential(\n",
    "    (0): Dropout(p=0.5)\n",
    "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
    "    (2): ReLU(inplace)\n",
    "    (3): Dropout(p=0.5)\n",
    "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
    "    (5): ReLU(inplace)\n",
    "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
    "  )\n",
    ")\n",
    "```\n",
    "\n",
    "某一天，我们突然发现了 AvgPool 的效果比 Maxpool 要好，于是我们想要改变把 AlextNet中的所有 Maxpool 用 AvgPool来代替该怎么办呢？\n",
    "\n",
    "> 查看 Alexnet 模型中的 Maxpool 的位置，我们可以清晰的发现 Maxpool 处于 features 下面的 (2)、(5)、(12)的位置上\n",
    "\n",
    "\n",
    "于是，我们改变 features 层下面的 (2)、(5)、(12) 就行了，具体操作如下所示\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models, transforms, datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "alexnet = models.alexnet(pretrained=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AlexNet(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=False)\n",
      "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "    (4): ReLU(inplace)\n",
      "    (5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=False)\n",
      "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (7): ReLU(inplace)\n",
      "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (9): ReLU(inplace)\n",
      "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU(inplace)\n",
      "    (12): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=False)\n",
      "  )\n",
      "  (classifier): Sequential(\n",
      "    (0): Dropout(p=0.5)\n",
      "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
      "    (2): ReLU(inplace)\n",
      "    (3): Dropout(p=0.5)\n",
      "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (5): ReLU(inplace)\n",
      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 改变前的 alexnet\n",
    "print(alexnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 开始改变\n",
    "\n",
    "alexnet.features._modules['2'] = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2), ceil_mode=False)\n",
    "alexnet.features._modules['5'] = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2), ceil_mode=False)\n",
    "alexnet.features._modules['12'] = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2), ceil_mode=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AlexNet(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): AvgPool2d(kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False, count_include_pad=True)\n",
      "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "    (4): ReLU(inplace)\n",
      "    (5): AvgPool2d(kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False, count_include_pad=True)\n",
      "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (7): ReLU(inplace)\n",
      "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (9): ReLU(inplace)\n",
      "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU(inplace)\n",
      "    (12): AvgPool2d(kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False, count_include_pad=True)\n",
      "  )\n",
      "  (classifier): Sequential(\n",
      "    (0): Dropout(p=0.5)\n",
      "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
      "    (2): ReLU(inplace)\n",
      "    (3): Dropout(p=0.5)\n",
      "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (5): ReLU(inplace)\n",
      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    " # 改变后的 alexnet\n",
    "print(alexnet)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
